{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.938 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import mindspore\n",
    "from mindspore.dataset import transforms\n",
    "\n",
    "from mindnlp.engine import Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from mindnlp.dataset import load_dataset\n",
    "\n",
    "imdb_ds = load_dataset('imdb', split=['train', 'test'])\n",
    "imdb_train = imdb_ds['train']\n",
    "imdb_test = imdb_ds['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "25000"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imdb_train.get_dataset_size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from mindnlp.transformers import AutoTokenizer\n",
    "# tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')\n",
    "\n",
    "test_tokenized = tokenizer('hello')\n",
    "test_tokenized.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_dataset(dataset, tokenizer, max_seq_len=256, batch_size=32, shuffle=False):\n",
    "    is_ascend = mindspore.get_context('device_target') == 'Ascend'\n",
    "    def tokenize(text):\n",
    "        if is_ascend:\n",
    "            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)\n",
    "        else:\n",
    "            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)\n",
    "        return tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']\n",
    "\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(batch_size)\n",
    "\n",
    "    # map dataset\n",
    "    dataset = dataset.map(operations=[tokenize], input_columns=\"text\", output_columns=['input_ids', 'token_type_ids', 'attention_mask'])\n",
    "    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns=\"label\", output_columns=\"labels\")\n",
    "    # batch dataset\n",
    "    if is_ascend:\n",
    "        dataset = dataset.batch(batch_size)\n",
    "    else:\n",
    "        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),\n",
    "                                                             'token_type_ids': (None, 0),\n",
    "                                                             'attention_mask': (None, 0)})\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split train dataset into train and valid datasets\n",
    "imdb_train, imdb_val = imdb_train.split([0.7, 0.3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_train = process_dataset(imdb_train, tokenizer, shuffle=True)\n",
    "dataset_val = process_dataset(imdb_val, tokenizer)\n",
    "dataset_test = process_dataset(imdb_test, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Tensor(shape=[32, 256], dtype=Int64, value=\n",
       " [[ 101, 1398,  178 ...    0,    0,    0],\n",
       "  [ 101, 1188, 1437 ...    0,    0,    0],\n",
       "  [ 101, 5145, 2568 ... 1116, 1132,  102],\n",
       "  ...\n",
       "  [ 101, 3517,  125 ...    0,    0,    0],\n",
       "  [ 101,  146, 1541 ...    0,    0,    0],\n",
       "  [ 101, 1188, 1110 ...    0,    0,    0]]),\n",
       " Tensor(shape=[32, 256], dtype=Int64, value=\n",
       " [[0, 0, 0 ... 0, 0, 0],\n",
       "  [0, 0, 0 ... 0, 0, 0],\n",
       "  [0, 0, 0 ... 0, 0, 0],\n",
       "  ...\n",
       "  [0, 0, 0 ... 0, 0, 0],\n",
       "  [0, 0, 0 ... 0, 0, 0],\n",
       "  [0, 0, 0 ... 0, 0, 0]]),\n",
       " Tensor(shape=[32, 256], dtype=Int64, value=\n",
       " [[1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 1, 1, 1],\n",
       "  ...\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0]]),\n",
       " Tensor(shape=[32], dtype=Int32, value= [1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, \n",
       "  0, 0, 1, 0, 1, 0, 1, 1])]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(dataset_train.create_tuple_iterator())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import AutoModelForSequenceClassification\n",
    "\n",
    "# set bert config and define parameters for training\n",
    "model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', num_labels=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindnlp.engine import TrainingArguments\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"bert_imdb_finetune\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,\n",
    "    num_train_epochs=3.0,\n",
    "    learning_rate=2e-5\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindnlp import evaluate\n",
    "import numpy as np\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=dataset_train,\n",
    "    eval_dataset=dataset_val,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f3b06a2f5d244b09a880b0a68c18982c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|                                                                                                | 0/1641 …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.3206, 'learning_rate': 2e-05, 'epoch': 1.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|                                                                                                 | 0/235 …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.2645550072193146, 'eval_accuracy': 0.89, 'eval_runtime': 21.6942, 'eval_samples_per_second': 10.832, 'eval_steps_per_second': 1.383, 'epoch': 1.0}\n",
      "{'loss': 0.1771, 'learning_rate': 2e-05, 'epoch': 2.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|                                                                                                 | 0/235 …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.28401243686676025, 'eval_accuracy': 0.9021333333333333, 'eval_runtime': 21.4019, 'eval_samples_per_second': 10.98, 'eval_steps_per_second': 1.402, 'epoch': 2.0}\n",
      "{'loss': 0.1088, 'learning_rate': 2e-05, 'epoch': 3.0}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|                                                                                                 | 0/235 …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.3973828852176666, 'eval_accuracy': 0.9024, 'eval_runtime': 21.3734, 'eval_samples_per_second': 10.995, 'eval_steps_per_second': 1.404, 'epoch': 3.0}\n",
      "{'train_runtime': 882.004, 'train_samples_per_second': 59.537, 'train_steps_per_second': 1.861, 'train_loss': 0.20214975555810458, 'epoch': 3.0}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1641, training_loss=0.20214975555810458, metrics={'train_runtime': 882.004, 'train_samples_per_second': 59.537, 'train_steps_per_second': 1.861, 'train_loss': 0.20214975555810458, 'epoch': 3.0})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b1c9c5dec1f4ee3b12cbf3019b8ba63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|                                                                                                 | 0/782 …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.25377440452575684,\n",
       " 'eval_accuracy': 0.89504,\n",
       " 'eval_runtime': 70.9269,\n",
       " 'eval_samples_per_second': 11.025,\n",
       " 'eval_steps_per_second': 1.382,\n",
       " 'epoch': 3.0}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate(dataset_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
